use super::*;
use crate::{tensor_ops::*, tests::*};

/// Transposed 2d convolution with stride 1, padding 0, dilation 1 and a single
/// group. Checks the forward output (3 channels of 4x5) and the gradients of
/// `mean(exp(out))` with respect to both the input and the kernel.
#[test]
fn test_convtrans2d_default() {
    let dev = TestDevice::default();

    let img = dev
        .tensor([
            [
                [-1.44135797, 0.23273671, 0.55838293, 1.09627271],
                [0.05642751, 0.96609902, 0.24707083, -1.48412001],
                [0.40077326, 1.16620362, 0.48770329, -1.01286852],
            ],
            [
                [0.32589695, -0.91695106, -0.13670059, 0.23979346],
                [0.72553939, -0.38209674, 0.35545620, -0.37955058],
                [-0.43962145, -0.69825196, -1.05400932, 1.22050178],
            ],
        ])
        .to_dtype::<TestDtype>();
    let kernel = dev
        .tensor([
            [
                [[-0.46923456, -0.05046696], [0.30552489, 0.68838942]],
                [[-0.45128539, 0.28081128], [-1.64214528, -0.67228431]],
                [[0.82240576, 0.06074515], [-0.19825625, -0.28086993]],
            ],
            [
                [[-1.61626124, 0.44280547], [0.03486229, 0.80007958]],
                [[-0.69502026, 0.22409980], [0.58573663, -0.36501792]],
                [[1.57198501, 0.84847665], [0.52653229, -0.59273601]],
            ],
        ])
        .to_dtype::<TestDtype>();

    // convtrans2d arguments are (stride, padding, dilation, groups).
    let (stride, padding, dilation, groups) = (Const::<1>, Const::<0>, Const::<1>, Const::<1>);
    let out = (img.leaky_trace(), kernel.clone()).convtrans2d(stride, padding, dilation, groups);
    #[rustfmt::skip]
    assert_close_to_literal!(
        out,
        [
            [
                [0.14960037, 1.58987427, -0.45884517, -0.99068964, 0.05085630],
                [-1.62814808,-0.20966601,-1.31598091,2.07309437,0.85334837],
                [0.56502044, 1.26762176, 1.55388546, -2.00090408, -0.73376185],
                [0.10711999, 0.25611749, 0.35640538, -0.77446860, 0.27925056]
            ],
            [
                [0.42395884, 0.20055276, -0.29711384, -0.53522754, 0.36158341],
                [2.02807951,-0.06121790,-0.99166316,-0.90268230,-1.32635069],
                [0.45699552, -2.14002204, -0.02407408, 1.42854285, 1.12538254],
                [-0.91563028,-2.43303156,-1.94739747,2.43502688,0.23543060]
            ],
            [
                [-0.67307591,-1.06106889,-0.51954788,1.19646454,0.27005240],
                [1.64429677, 0.49562341, 0.79191101, -1.66748166, -0.86223716],
                [0.00935271, -1.32583618, -1.68409586, 0.03524891, 1.61585832],
                [-0.31093070,-0.45084506,-0.56533259,1.33120918,-0.43895102]
            ]
        ]
    );

    let grads = out.exp().mean().backward();

    // Gradient with respect to the input image.
    assert_close_to_literal!(
        grads.get(&img),
        [
            [
                [-0.24465635, -0.07769803, 0.07309557, 0.09593978],
                [0.01824830, 0.06927966, -0.03252371, -0.23244604],
                [-0.01082366, -0.01004899, -0.16018088, -0.32376897]
            ],
            [
                [0.11863519, -0.11774766, 0.16648589, 0.13729726],
                [0.14436747, 0.10569359, 0.08925933, -0.20175745],
                [0.01255126, -0.02915864, -0.19611855, 0.20987590]
            ]
        ]
    );
    // Gradient with respect to the kernel.
    assert_close_to_literal!(
        grads.get(&kernel),
        [
            [
                [[-0.06153677, -0.00425007], [0.25752968, 0.18442926]],
                [[-0.05023063, 0.00882470], [-0.45103958, 0.01904970]],
                [[0.08672050, 0.00499410], [-0.15565623, -0.10318264]]
            ],
            [
                [[-0.25463796, -0.01300730], [0.00675005, -0.00511352]],
                [[0.13059653, -0.01341964], [0.25117764, -0.17709662]],
                [[0.08071060, 0.07478670], [0.05806271, -0.11189780]]
            ]
        ]
    );
}

/// Transposed 2d convolution with stride 2 (padding 0, dilation 1, one group).
/// Checks the 3x6x8 forward output and the gradients of `mean(exp(out))` with
/// respect to the input and the kernel.
#[test]
fn test_convtrans2d_stride_2() {
    let dev = TestDevice::default();
    let img = dev
        .tensor([
            [
                [-1.44135797, 0.23273671, 0.55838293, 1.09627271],
                [0.05642751, 0.96609902, 0.24707083, -1.48412001],
                [0.40077326, 1.16620362, 0.48770329, -1.01286852],
            ],
            [
                [0.32589695, -0.91695106, -0.13670059, 0.23979346],
                [0.72553939, -0.38209674, 0.35545620, -0.37955058],
                [-0.43962145, -0.69825196, -1.05400932, 1.22050178],
            ],
        ])
        .to_dtype::<TestDtype>();
    let kernel = dev
        .tensor([
            [
                [[-0.46923456, -0.05046696], [0.30552489, 0.68838942]],
                [[-0.45128539, 0.28081128], [-1.64214528, -0.67228431]],
                [[0.82240576, 0.06074515], [-0.19825625, -0.28086993]],
            ],
            [
                [[-1.61626124, 0.44280547], [0.03486229, 0.80007958]],
                [[-0.69502026, 0.22409980], [0.58573663, -0.36501792]],
                [[1.57198501, 0.84847665], [0.52653229, -0.59273601]],
            ],
        ])
        .to_dtype::<TestDtype>();

    // convtrans2d arguments are (stride, padding, dilation, groups).
    let (stride, padding, dilation, groups) = (Const::<2>, Const::<0>, Const::<1>, Const::<1>);
    let out = (img.leaky_trace(), kernel.clone()).convtrans2d(stride, padding, dilation, groups);
    #[rustfmt::skip]
    assert_close_to_literal!(
        out,
        [
            [
                [0.14960037,0.21704991,1.37282431,-0.41777647,-0.04106871,-0.08871166,-0.90197796,0.05085630],
                [-0.42900923,-0.73147207,0.03913984,-0.57342035,0.16583419,0.27501357,0.34329835,0.94651639],
                [-1.19913888,0.31842509,0.16424109,-0.21795061,-0.69044423,0.14492904,1.30985332,-0.09316804],
                [0.04253398,0.61933333,0.28184652,0.35934454,0.08787831,0.45447421,-0.46666759,-1.32532310],
                [0.52248645,-0.21489260,0.58133453,-0.36804456,1.47470713,-0.49133399,-1.49737680,0.59156126],
                [0.10711999,-0.07584409,0.33196157,0.24414507,0.11226030,-0.50756156,-0.26690704,0.27925056]
            ],
            [
                [0.42395884,-0.33171612,0.53226888,-0.14013346,-0.15698038,0.12616566,-0.66139317,0.36158341],
                [2.55780911,0.85004413,-0.91927934,0.17823833,-0.99701643,-0.32549390,-1.65978324,-0.82453585],
                [-0.52972949,0.17843871,-0.17042139,0.18566369,-0.35854873,0.14903794,0.93355697,-0.50181484],
                [0.33231282,-0.30277020,-1.81028295,-0.51002109,-0.19752248,-0.29584974,2.21482396,1.13629329],
                [0.12468270,0.01402257,-0.04099140,0.17100500,0.51246446,-0.09925070,-0.39118069,-0.01091070],
                [-0.91563028,-0.10896385,-2.32406759,-0.52914590,-1.41825151,0.05685703,2.37816978,0.23543060]
            ],
            [
                [-0.67307591,0.18896043,-1.25002933,-0.76387393,0.24432607,-0.08206820,1.27853274,0.27005240],
                [0.45735350,0.21166326,-0.52894586,0.47814116,-0.18268019,-0.07580562,-0.09108391,-0.45004424],
                [1.18694329,0.61903095,0.19387504,-0.26551431,0.76196432,0.31660464,-1.81719673,-0.41219291],
                [0.37083280,-0.44590211,-0.39272144,-0.04486568,0.13817583,-0.28008646,0.09439043,0.64181799],
                [-0.36148009,-0.34866351,-0.13854906,-0.52160925,-1.25579679,-0.86467665,1.08562160,0.97404039],
                [-0.31093070,0.14801431,-0.59885937,0.08632754,-0.65166014,0.48776808,0.84344113,-0.43895102]
            ]
        ]
    );

    let grads = out.exp().mean().backward();

    // Gradient with respect to the input image.
    assert_close_to_literal!(
        grads.get(&img),
        [
            [
                [-0.16320729, -0.02408358, 0.00202971, 0.02914695],
                [0.00702363, 0.00503576, 0.00480992, -0.13959591],
                [-0.00830228, 0.00076824, -0.02246094, -0.10893497]
            ],
            [
                [0.04730723, -0.04577319, 0.01554872, 0.06038214],
                [0.06361489, 0.01077521, 0.03939442, -0.01726136],
                [-0.00589108, -0.00387334, -0.05110299, 0.10329218]
            ]
        ]
    );
    // Gradient with respect to the kernel.
    assert_close_to_literal!(
        grads.get(&kernel),
        [
            [
                [[0.00466055, -0.00407176], [0.02797192, 0.03677537]],
                [[-0.01260271, 0.02238824], [-0.29015976, -0.03991098]],
                [[0.02820409, -0.00682320], [-0.01703458, 0.00010475]]
            ],
            [
                [[-0.07678317, 0.00976532], [-0.01299408, 0.00706887]],
                [[-0.02618737, -0.00559338], [0.09905507, -0.00913252]],
                [[0.03891469, 0.02437293], [0.01698377, -0.02502952]]
            ]
        ]
    );
}

/// Transposed 2d convolution with padding 1 (stride 1, dilation 1, one group).
/// The padding trims the output down to 3 channels of 2x3. Checks the forward
/// output and the gradients of `mean(exp(out))` with respect to the input and
/// the kernel.
#[test]
fn test_convtrans2d_padded() {
    let dev = TestDevice::default();
    let img = dev
        .tensor([
            [
                [-1.44135797, 0.23273671, 0.55838293, 1.09627271],
                [0.05642751, 0.96609902, 0.24707083, -1.48412001],
                [0.40077326, 1.16620362, 0.48770329, -1.01286852],
            ],
            [
                [0.32589695, -0.91695106, -0.13670059, 0.23979346],
                [0.72553939, -0.38209674, 0.35545620, -0.37955058],
                [-0.43962145, -0.69825196, -1.05400932, 1.22050178],
            ],
        ])
        .to_dtype::<TestDtype>();
    let kernel = dev
        .tensor([
            [
                [[-0.46923456, -0.05046696], [0.30552489, 0.68838942]],
                [[-0.45128539, 0.28081128], [-1.64214528, -0.67228431]],
                [[0.82240576, 0.06074515], [-0.19825625, -0.28086993]],
            ],
            [
                [[-1.61626124, 0.44280547], [0.03486229, 0.80007958]],
                [[-0.69502026, 0.22409980], [0.58573663, -0.36501792]],
                [[1.57198501, 0.84847665], [0.52653229, -0.59273601]],
            ],
        ])
        .to_dtype::<TestDtype>();

    // convtrans2d arguments are (stride, padding, dilation, groups).
    let (stride, padding, dilation, groups) = (Const::<1>, Const::<1>, Const::<1>, Const::<1>);
    let out = (img.leaky_trace(), kernel.clone()).convtrans2d(stride, padding, dilation, groups);
    assert_close_to_literal!(
        out,
        [
            [
                [-0.20966601, -1.31598091, 2.07309437],
                [1.26762176, 1.55388546, -2.00090408]
            ],
            [
                [-0.06121790, -0.99166316, -0.90268230],
                [-2.14002204, -0.02407408, 1.42854285]
            ],
            [
                [0.49562341, 0.79191101, -1.66748166],
                [-1.32583618, -1.68409586, 0.03524891]
            ]
        ]
    );

    let grads = out.exp().mean().backward();

    // Gradient with respect to the input image.
    assert_close_to_literal!(
        grads.get(&img),
        [
            [
                [-0.02973516, -0.12817474, 0.23232058, 0.09585897],
                [0.14525947, 0.23093222, -0.10841238, -0.59855586],
                [-0.00722821, -0.08082656, -0.07108198, -0.06080972]
            ],
            [
                [-0.03708792, 0.01189943, 0.41607130, 0.03411148],
                [0.25580385, 0.35231194, 0.29753110, -0.54662442],
                [0.10137358, -0.16306859, -0.34208044, -0.08278930]
            ]
        ]
    );
    // Gradient with respect to the kernel.
    assert_close_to_literal!(
        grads.get(&kernel),
        [
            [
                [[-0.25753322, 0.51525033], [0.74739242, 0.45198986]],
                [[-0.17857774, 0.20734756], [-0.27595735, 0.05209281]],
                [[0.06679222, 0.17222919], [0.03259417, -0.07203376]]
            ],
            [
                [[-0.58513254, -0.09418570], [0.07769967, -0.01389713]],
                [[0.20000815, -0.24702756], [-0.11653619, 0.06147216]],
                [[0.05383727, -0.05131300], [-0.12168837, -0.05695146]]
            ]
        ]
    );
}

/// Runs the same transposed convolution (stride 3, padding 2, dilation 1, one
/// group) on a single random image and on a batch of ten copies of it. Every
/// batch element must reproduce the single-image output, and both the kernel
/// and input gradients must agree with the single-image case.
#[test]
fn test_convtrans2d_batched() {
    let dev: TestDevice = Default::default();
    let x: Tensor<Rank3<3, 28, 28>, TestDtype, _> = dev.sample_normal();
    let w: Tensor<Rank4<3, 5, 6, 6>, TestDtype, _> = dev.sample_normal();

    // Unbatched reference run.
    let single: Tensor<Rank3<5, 83, 83>, _, _, _> =
        (x.leaky_trace(), w.clone()).convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>);
    let single_out = single.retaped::<NoneTape>();
    let single_grads = single.square().mean().backward();
    let single_x_grad = single_grads.get(&x);
    let single_w_grad = single_grads.get(&w);

    // Ten copies of the same image. The reshape to the identical shape
    // presumably materializes the broadcast into an ordinary tensor — TODO confirm.
    let batch_x = x
        .broadcast::<Rank4<10, 3, 28, 28>, _>()
        .reshape::<Rank4<10, 3, 28, 28>>();

    let batch: Tensor<Rank4<10, 5, 83, 83>, _, _, _> = (batch_x.leaky_trace(), w.clone())
        .convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>);
    let batch_out = batch.retaped::<NoneTape>();
    for i in 0..10 {
        assert_close_to_tensor!(single_out, batch_out.clone().select(dev.tensor(i)), 1e-5);
    }

    let batch_grads = batch.square().mean().backward();

    // The kernel is shared by all ten identical copies, so its gradient matches.
    assert_close_to_tensor!(single_w_grad, batch_grads.get(&w));

    // The batched mean averages over 10x as many elements, so each slice of the
    // input gradient is 1/10 of the single-image gradient; rescale before comparing.
    let batch_x_grad = batch_grads.get(&batch_x) * 10.0;
    for i in 0..10 {
        assert_close_to_tensor!(single_x_grad, batch_x_grad.clone().select(dev.tensor(i)));
    }
}

/// Transposed 2d convolution with 2 groups (stride 1, padding 0, dilation 1).
/// With two input channels and a kernel of shape 2x3x2x2, the expected output
/// has 6 channels of 4x5. Checks the forward output and the gradients of
/// `mean(exp(out))` with respect to the input and the kernel.
#[test]
fn test_convtrans2d_grouped() {
    let dev = TestDevice::default();
    let img = dev
        .tensor([
            [
                [-1.44135797, 0.23273671, 0.55838293, 1.09627271],
                [0.05642751, 0.96609902, 0.24707083, -1.48412001],
                [0.40077326, 1.16620362, 0.48770329, -1.01286852],
            ],
            [
                [0.32589695, -0.91695106, -0.13670059, 0.23979346],
                [0.72553939, -0.38209674, 0.35545620, -0.37955058],
                [-0.43962145, -0.69825196, -1.05400932, 1.22050178],
            ],
        ])
        .to_dtype::<TestDtype>();
    let kernel = dev
        .tensor([
            [
                [[-0.46923456, -0.05046696], [0.30552489, 0.68838942]],
                [[-0.45128539, 0.28081128], [-1.64214528, -0.67228431]],
                [[0.82240576, 0.06074515], [-0.19825625, -0.28086993]],
            ],
            [
                [[-1.61626124, 0.44280547], [0.03486229, 0.80007958]],
                [[-0.69502026, 0.22409980], [0.58573663, -0.36501792]],
                [[1.57198501, 0.84847665], [0.52653229, -0.59273601]],
            ],
        ])
        .to_dtype::<TestDtype>();

    // convtrans2d arguments are (stride, padding, dilation, groups).
    let (stride, padding, dilation, groups) = (Const::<1>, Const::<0>, Const::<1>, Const::<2>);
    let out = (img.leaky_trace(), kernel.clone()).convtrans2d(stride, padding, dilation, groups);
    #[rustfmt::skip]
    assert_close_to_literal!(
        out,
        [
            [
                [0.67633498,-0.03646715,-0.27375808,-0.54258895,-0.05532555],
                [-0.46684846, -1.37728357, 0.16612311, 1.40325499, 0.82956153],
                [-0.17081666,-0.23343745,0.45283663,0.16730537,-0.97053605],
                [0.12244620, 0.63219225, 0.95180768, 0.02627325, -0.69724798]
            ],
            [
                [0.65046382,-0.50978023,-0.18663496,-0.33793163,0.30784574],
                [2.34145427,0.16667402,-0.91361910,-1.43648922,-1.15376461],
                [-0.27352527,-2.03815913,-0.94782966,2.86508417,0.71332568],
                [-0.65812790,-2.18450928,-1.58490002,1.33540201,0.68093562]
            ],
            [
                [-1.18538105, 0.10384849, 0.47335497, 0.93550003, 0.06659325],
                [0.33216453, 1.15664566, 0.08580667, -1.57971644, -0.39806312],
                [0.31841114, 0.77605361, 0.15159971, -0.57852203, 0.35531783],
                [-0.07945581,-0.34377232,-0.42424178,0.06382634,0.28448433]
            ],
            [
                [-0.52673459,1.62634134,-0.18508710,-0.44810066,0.10618186],
                [-1.16129971, 1.16761744, -1.48210406, 0.66983926, 0.02378678],
                [0.73583704, 1.50105929, 1.10104883, -2.16820955, 0.23677418],
                [-0.01532621,-0.37607479,-0.59540236,-0.80074185,0.97649854]
            ],
            [
                [-0.22650498,0.71033299,-0.11047887,-0.19729590,0.05373767],
                [-0.31337482,-0.22789186,-0.07804406,0.53380698,-0.17258611],
                [0.73052084, -0.10186276, 0.92375565, -1.43654132, 0.41205698],
                [-0.25750238,-0.24852204,-0.36249739,1.09962487,-0.44550502]
            ],
            [
                [0.51230514, -1.16491735, -0.99290282, 0.26096445, 0.20345916],
                [1.31213224,-0.66102237,0.70610434,-0.08776516,-0.46417403],
                [-0.30905840,-2.10188961,-1.83569562,0.61377084,1.26054060],
                [-0.23147489,-0.10707274,-0.14109090,1.26738286,-0.72343534]
            ]
        ]
    );

    let grads = out.exp().mean().backward();

    // Gradient with respect to the input image.
    assert_close_to_literal!(
        grads.get(&img),
        [
            [
                [-0.16682972, -0.01479191, 0.02486472, 0.03241257],
                [-0.03953645, 0.01812358, -0.09578598, -0.26596320],
                [0.00710792, 0.02639988, 0.02603059, -0.12393593]
            ],
            [
                [0.07047811, -0.07305276, 0.01832999, 0.03246089],
                [0.09772546, -0.00640025, 0.03977848, -0.01866792],
                [-0.00610692, -0.04190996, -0.05539178, 0.09744944]
            ]
        ]
    );
    // Gradient with respect to the kernel.
    assert_close_to_literal!(
        grads.get(&kernel),
        [
            [
                [[-0.05153870, 0.01042829], [0.05493409, 0.08134348]],
                [[-0.14194693, 0.06882061], [-0.36182982, 0.00579517]],
                [[0.08088457, 0.02259952], [0.01361402, -0.03323809]]
            ],
            [
                [[-0.10977639, 0.00816289], [-0.02106360, 0.03803319]],
                [[-0.04609783, -0.00057952], [0.03170185, -0.03856671]],
                [[0.04222549, 0.01465542], [0.02470277, -0.05392574]]
            ]
        ]
    );
}

/// Transposed 2d convolution with dilation 2 (stride 1, padding 0, one group).
/// The dilated 2x2 kernel widens the output to 3 channels of 5x6. Checks the
/// forward output and the gradients of `mean(exp(out))` with respect to the
/// input and the kernel.
#[test]
fn test_convtrans2d_dilated() {
    let dev = TestDevice::default();
    let img = dev
        .tensor([
            [
                [-1.44135797, 0.23273671, 0.55838293, 1.09627271],
                [0.05642751, 0.96609902, 0.24707083, -1.48412001],
                [0.40077326, 1.16620362, 0.48770329, -1.01286852],
            ],
            [
                [0.32589695, -0.91695106, -0.13670059, 0.23979346],
                [0.72553939, -0.38209674, 0.35545620, -0.37955058],
                [-0.43962145, -0.69825196, -1.05400932, 1.22050178],
            ],
        ])
        .to_dtype::<TestDtype>();
    let kernel = dev
        .tensor([
            [
                [[-0.46923456, -0.05046696], [0.30552489, 0.68838942]],
                [[-0.45128539, 0.28081128], [-1.64214528, -0.67228431]],
                [[0.82240576, 0.06074515], [-0.19825625, -0.28086993]],
            ],
            [
                [[-1.61626124, 0.44280547], [0.03486229, 0.80007958]],
                [[-0.69502026, 0.22409980], [0.58573663, -0.36501792]],
                [[1.57198501, 0.84847665], [0.52653229, -0.59273601]],
            ],
        ])
        .to_dtype::<TestDtype>();

    // convtrans2d arguments are (stride, padding, dilation, groups).
    let (stride, padding, dilation, groups) = (Const::<1>, Const::<0>, Const::<2>, Const::<1>);
    let out = (img.leaky_trace(), kernel.clone()).convtrans2d(stride, padding, dilation, groups);
    #[rustfmt::skip]
    assert_close_to_literal!(
        out,
        [
            [
                [0.14960037,1.37282431,0.17598119,-1.31975436,-0.08871166,0.05085630],
                [-1.19913888,0.16424109,-0.37201914,1.09190273,0.14492904,-0.09316804],
                [0.09347722,0.62047440,0.69417661,-2.09554338,-0.21632043,1.53807759],
                [0.04253398,0.28184652,0.70721161,-0.10732305,0.45447421,-1.32532310],
                [0.10711999,0.33196157,0.03641622,-0.02276197,-0.50756156,0.27925056]
            ],
            [
                [0.42395884,0.53226888,-0.48869652,-0.80152661,0.12616566,0.36158341],
                [-0.52972949,-0.17042139,-0.18011002,1.11922061,0.14903794,-0.50181484],
                [2.68249178,-0.96027076,0.37951475,-1.70172071,-0.42474461,-0.83544654],
                [0.33231282,-1.81028295,-0.50029266,1.70480287,-0.29584974,1.13629329],
                [-0.91563028,-2.32406759,-1.52721536,1.84902382,0.05685703,0.23543060]
            ],
            [
                [-0.67307591,-1.25002933,0.43328649,0.51465881,-0.08206820,0.27005240],
                [1.18694329,0.19387504,1.38099527,-2.08271098,0.31660464,-0.41219291],
                [0.09587342,-0.66749489,-1.57547712,0.95106959,-0.94048226,0.52399611],
                [0.37083280,-0.39272144,-0.30772626,0.04952475,-0.28008646,0.64181799],
                [-0.31093070,-0.59885937,-0.50364584,0.92976868,0.48776808,-0.43895102]
            ]
        ]
    );

    let grads = out.exp().mean().backward();

    // Gradient with respect to the input image.
    assert_close_to_literal!(
        grads.get(&img),
        [
            [
                [-0.26850072, -0.03441733, -0.01182073, 0.03490706],
                [0.01392877, -0.02877949, 0.03081299, -0.15472011],
                [-0.06612058, -0.05083767, -0.02418187, -0.09756426]
            ],
            [
                [0.11071943, -0.07129460, 0.03023484, 0.08694579],
                [0.11894555, 0.00127397, 0.08722699, -0.04324878],
                [-0.08628573, -0.03130906, -0.03867241, 0.14042965]
            ]
        ]
    );
    // Gradient with respect to the kernel.
    assert_close_to_literal!(
        grads.get(&kernel),
        [
            [
                [[0.00630402, -0.01686253], [0.02441918, 0.04648526]],
                [[0.02635447, 0.04999616], [-0.37556821, 0.07128908]],
                [[0.03001181, 0.01555148], [-0.00549905, 0.04595280]]
            ],
            [
                [[-0.09109048, 0.03960687], [-0.02581818, 0.03304695]],
                [[-0.10903731, -0.01481715], [0.11795966, -0.07065595]],
                [[0.06082013, 0.02204099], [0.03207079, -0.05822099]]
            ]
        ]
    );
}
